{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import *\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from overrides import overrides\n",
    "\n",
    "from allennlp.data import Instance\n",
    "from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer\n",
    "from allennlp.data.tokenizers import Token\n",
    "from allennlp.nn import util as nn_util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config(dict):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        for k, v in kwargs.items():\n",
    "            setattr(self, k, v)\n",
    "    \n",
    "    def set(self, key, val):\n",
    "        self[key] = val\n",
    "        setattr(self, key, val)\n",
    "        \n",
    "# Hyperparameters and run-time switches shared by the cells below.\n",
    "config = Config(\n",
    "    testing=True,  # when True the dataset reader keeps only the first 1000 rows of each CSV\n",
    "    seed=1,  # passed to torch.manual_seed\n",
    "    batch_size=64,  # batch size of the training BucketIterator\n",
    "    lr=3e-4,  # learning rate handed to the Adam optimizer\n",
    "    epochs=2,  # num_epochs given to the Trainer\n",
    "    hidden_sz=64,  # NOTE(review): not referenced by any cell visible in this notebook\n",
    "    max_seq_len=100, # necessary to limit memory usage\n",
    "    max_vocab_size=100000,  # NOTE(review): not referenced by any cell visible in this notebook\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.common.checks import ConfigurationError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_GPU = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_ROOT = Path(\"../data\") / \"jigsaw\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set random seed manually to replicate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x113989690>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch.manual_seed alone leaves Python's and NumPy's generators unseeded,\n",
    "# so seed all three to make reruns repeatable.\n",
    "random.seed(config.seed)\n",
    "np.random.seed(config.seed)\n",
    "torch.manual_seed(config.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.data.dataset_readers import DatasetReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_cols = [\"toxic\", \"severe_toxic\", \"obscene\",\n",
    "              \"threat\", \"insult\", \"identity_hate\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.data.fields import TextField, MetadataField, ArrayField\n",
    "\n",
    "class JigsawDatasetReader(DatasetReader):\n",
    "    \"\"\"Turn rows of a Jigsaw toxic-comment CSV into AllenNLP instances.\n",
    "\n",
    "    Every row must provide `comment_text` and `id`. The columns named in\n",
    "    `label_cols` supply the multi-label target; a file that has none of\n",
    "    them gets all-zero labels.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: maps a raw comment to a list of token strings\n",
    "            (whitespace splitting by default).\n",
    "        token_indexers: indexers given to every `TextField`; defaults to a\n",
    "            `SingleIdTokenIndexer` under the key \"tokens\".\n",
    "        max_seq_len: stored on the reader but not applied by it -- the\n",
    "            `tokenizer` is responsible for truncation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),\n",
    "                 token_indexers: Dict[str, TokenIndexer] = None,\n",
    "                 max_seq_len: Optional[int]=config.max_seq_len) -> None:\n",
    "        super().__init__(lazy=False)\n",
    "        self.tokenizer = tokenizer\n",
    "        # SingleIdTokenIndexer comes from the notebook's import cell.\n",
    "        self.token_indexers = token_indexers or {\"tokens\": SingleIdTokenIndexer()}\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    @overrides\n",
    "    def text_to_instance(self, tokens: List[Token], id: str=None,\n",
    "                         labels: np.ndarray=None) -> Instance:\n",
    "        \"\"\"Build one instance with the fields `tokens`, `id` and `label`.\n",
    "\n",
    "        When `labels` is None a zero vector of length `len(label_cols)` is used,\n",
    "        so every instance carries a `label` field.\n",
    "        \"\"\"\n",
    "        if labels is None:\n",
    "            labels = np.zeros(len(label_cols))\n",
    "\n",
    "        fields = {\n",
    "            \"tokens\": TextField(tokens, self.token_indexers),\n",
    "            \"id\": MetadataField(id),\n",
    "            \"label\": ArrayField(array=labels),\n",
    "        }\n",
    "        return Instance(fields)\n",
    "\n",
    "    @overrides\n",
    "    def _read(self, file_path: str) -> Iterator[Instance]:\n",
    "        \"\"\"Yield one instance per CSV row (first 1000 rows only when `config.testing`).\"\"\"\n",
    "        df = pd.read_csv(file_path)\n",
    "        if config.testing:\n",
    "            df = df.head(1000)\n",
    "        # A file with none of the label columns is treated as unlabelled: pass\n",
    "        # labels=None so text_to_instance substitutes zeros instead of the\n",
    "        # column lookup failing. A partially labelled file still raises KeyError.\n",
    "        has_labels = any(col in df.columns for col in label_cols)\n",
    "        for _, row in df.iterrows():\n",
    "            tokens = [Token(x) for x in self.tokenizer(row[\"comment_text\"])]\n",
    "            labels = row[label_cols].values if has_labels else None\n",
    "            yield self.text_to_instance(tokens, row[\"id\"], labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare token handlers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the wordpiece tokenizer that ships with the pretrained BERT token indexer here (not the spaCy tokenizer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/07/2019 17:37:35 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /Users/keitakurita/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "from allennlp.data.token_indexers import PretrainedBertIndexer\n",
    "\n",
    "token_indexer = PretrainedBertIndexer(\n",
    "    pretrained_model=\"bert-base-uncased\",\n",
    "    max_pieces=config.max_seq_len,\n",
    "    do_lowercase=True,\n",
    " )\n",
    "# Truncation has to happen here: the indexer is capped at `max_pieces` wordpieces, and the\n",
    "# `- 2` presumably leaves room for the two special pieces it adds per sequence (TODO confirm).\n",
    "def tokenizer(s: str):\n",
    "    \"\"\"Split a raw comment into at most `config.max_seq_len - 2` wordpieces.\"\"\"\n",
    "    # The \"bert-base-uncased\" vocabulary is lowercase-only, so cased input must be\n",
    "    # lowercased first; otherwise every capitalised word is emitted as [UNK].\n",
    "    return token_indexer.wordpiece_tokenizer(s.lower())[:config.max_seq_len - 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "reader = JigsawDatasetReader(\n",
    "    tokenizer=tokenizer,\n",
    "    token_indexers={\"tokens\": token_indexer}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "267it [00:00, 295.56it/s]\n",
      "251it [00:01, 168.06it/s]\n"
     ]
    }
   ],
   "source": [
    "train_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in [\"train.csv\", \"test_proced.csv\"])\n",
    "val_ds = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "267"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<allennlp.data.instance.Instance at 0x1a2d3306a0>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d335160>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d336630>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d339e80>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d33f710>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d33fe10>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d342160>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d342d68>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d3494a8>,\n",
       " <allennlp.data.instance.Instance at 0x1a2d349a90>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_ds[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tokens': [[UNK],\n",
       "  [UNK],\n",
       "  the,\n",
       "  edit,\n",
       "  ##s,\n",
       "  made,\n",
       "  under,\n",
       "  my,\n",
       "  user,\n",
       "  ##name,\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  were,\n",
       "  reverted,\n",
       "  ##?,\n",
       "  [UNK],\n",
       "  weren,\n",
       "  ##',\n",
       "  ##t,\n",
       "  van,\n",
       "  ##dal,\n",
       "  ##isms,\n",
       "  ##,,\n",
       "  just,\n",
       "  closure,\n",
       "  on,\n",
       "  some,\n",
       "  [UNK],\n",
       "  after,\n",
       "  [UNK],\n",
       "  voted,\n",
       "  at,\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  [UNK],\n",
       "  please,\n",
       "  don,\n",
       "  ##',\n",
       "  ##t,\n",
       "  remove,\n",
       "  the,\n",
       "  template,\n",
       "  from,\n",
       "  the,\n",
       "  talk,\n",
       "  page,\n",
       "  since,\n",
       "  [UNK],\n",
       "  retired,\n",
       "  now,\n",
       "  ##.,\n",
       "  ##8,\n",
       "  ##9,\n",
       "  ##.,\n",
       "  ##20,\n",
       "  ##5,\n",
       "  ##.,\n",
       "  ##38,\n",
       "  ##.,\n",
       "  ##27],\n",
       " '_token_indexers': {'tokens': <allennlp.data.token_indexers.wordpiece_indexer.PretrainedBertIndexer at 0x1a2cbac128>},\n",
       " '_indexed_tokens': None,\n",
       " '_indexer_name_to_indexed_token': None}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(train_ds[0].fields[\"tokens\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare vocabulary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We don't need to build the vocab: all that is handled by the token indexer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = Vocabulary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare iterator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iterator is responsible for batching the data and preparing it for input into the model. We'll use the BucketIterator, which batches text sequences of similar lengths together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.data.iterators import BucketIterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = BucketIterator(batch_size=config.batch_size, \n",
    "                          sorting_keys=[(\"tokens\", \"num_tokens\")],\n",
    "                         )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to tell the iterator how to numericalize the text data. We do this by passing the vocabulary to the iterator. This step is easy to forget so be careful! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(iterator(train_ds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tokens': {'tokens': tensor([[  101,   100,   100,  ...,   102,     0,     0],\n",
       "          [  101,   100,  1997,  ...,  1997,   100,   102],\n",
       "          [  101,  1000,   100,  ...,  2062,  2411,   102],\n",
       "          ...,\n",
       "          [  101,   100,   100,  ..., 29625,   100,   102],\n",
       "          [  101,   100,  1997,  ..., 12064,  2135,   102],\n",
       "          [  101,   100,  5993,  ...,  2089,  2421,   102]]),\n",
       "  'tokens-offsets': tensor([[ 1,  2,  3,  ..., 96,  0,  0],\n",
       "          [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "          [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "          ...,\n",
       "          [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "          [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "          [ 1,  2,  3,  ..., 96, 97, 98]]),\n",
       "  'mask': tensor([[1, 1, 1,  ..., 1, 0, 0],\n",
       "          [1, 1, 1,  ..., 1, 1, 1],\n",
       "          [1, 1, 1,  ..., 1, 1, 1],\n",
       "          ...,\n",
       "          [1, 1, 1,  ..., 1, 1, 1],\n",
       "          [1, 1, 1,  ..., 1, 1, 1],\n",
       "          [1, 1, 1,  ..., 1, 1, 1]])},\n",
       " 'id': ['003a19c04c079bf7',\n",
       "  '0098257c6952c9e3',\n",
       "  '0048de0c9422f64f',\n",
       "  '001d874a4d3e8813',\n",
       "  '00822d0d01752c7e',\n",
       "  '00328eadb85b3010',\n",
       "  '0006f16e4e9f292e',\n",
       "  '0082b5a7b4a67da2',\n",
       "  '006854d70298693e',\n",
       "  '006774d59329b7bd',\n",
       "  '004b975fabbbffa9'],\n",
       " 'label': tensor([[0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [1., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0.]])}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  101,   100,   100,  ...,   102,     0,     0],\n",
       "        [  101,   100,  1997,  ...,  1997,   100,   102],\n",
       "        [  101,  1000,   100,  ...,  2062,  2411,   102],\n",
       "        ...,\n",
       "        [  101,   100,   100,  ..., 29625,   100,   102],\n",
       "        [  101,   100,  1997,  ..., 12064,  2135,   102],\n",
       "        [  101,   100,  5993,  ...,  2089,  2421,   102]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"tokens\"][\"tokens\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([11, 100])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"tokens\"][\"tokens\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper\n",
    "from allennlp.nn.util import get_text_field_mask\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.text_field_embedders import TextFieldEmbedder\n",
    "\n",
    "class BaselineModel(Model):\n",
    "    \"\"\"Multi-label classifier: embed the tokens, pool to one vector, project to logits.\n",
    "\n",
    "    Args:\n",
    "        word_embeddings: embedder applied to the `tokens` input.\n",
    "        encoder: reduces the embedded sequence to one vector per example.\n",
    "        out_sz: number of output logits (defaults to `len(label_cols)`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, word_embeddings: TextFieldEmbedder,\n",
    "                 encoder: Seq2VecEncoder,\n",
    "                 out_sz: int=len(label_cols)):\n",
    "        # NOTE: `vocab` is the notebook-level Vocabulary, not a constructor argument.\n",
    "        super().__init__(vocab)\n",
    "        self.word_embeddings = word_embeddings\n",
    "        self.encoder = encoder\n",
    "        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)\n",
    "        # Sigmoid + binary cross-entropy per class, applied directly to the logits.\n",
    "        self.loss = nn.BCEWithLogitsLoss()\n",
    "\n",
    "    def forward(self, tokens: Dict[str, torch.Tensor],\n",
    "                id: Any = None,\n",
    "                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Compute class logits and, when `label` is given, the training loss.\n",
    "\n",
    "        Args:\n",
    "            tokens: indexed text field as produced by the data iterator.\n",
    "            id: per-example metadata; accepted so batches can be splatted in, not used.\n",
    "            label: target tensor for `BCEWithLogitsLoss`; omit it at inference time.\n",
    "\n",
    "        Returns:\n",
    "            A dict with \"class_logits\" and, only when `label` is not None, \"loss\".\n",
    "        \"\"\"\n",
    "        mask = get_text_field_mask(tokens)\n",
    "        embeddings = self.word_embeddings(tokens)\n",
    "        state = self.encoder(embeddings, mask)\n",
    "        class_logits = self.projection(state)\n",
    "\n",
    "        output = {\"class_logits\": class_logits}\n",
    "        if label is not None:\n",
    "            output[\"loss\"] = self.loss(class_logits, label)\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/07/2019 17:37:42 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/keitakurita/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
      "02/07/2019 17:37:42 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /Users/keitakurita/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/hy/1czs1y5j2d58zgkqx6w_wnpw0000gn/T/tmpm99yyacp\n",
      "02/07/2019 17:37:48 - INFO - pytorch_pretrained_bert.modeling -   Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder\n",
    "\n",
    "bert_embedder = PretrainedBertEmbedder(\n",
    "        pretrained_model=\"bert-base-uncased\",\n",
    "        top_layer_only=True, # conserve memory\n",
    ")\n",
    "word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({\"tokens\": bert_embedder},\n",
    "                                                            # we'll be ignoring masks so we'll need to set this to True\n",
    "                                                           allow_unmatched_keys = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_DIM = word_embeddings.get_output_dim()\n",
    "\n",
    "class BertSentencePooler(Seq2VecEncoder):\n",
    "    \"\"\"Seq2Vec encoder that returns the embedding at sequence position 0.\n",
    "\n",
    "    Presumably position 0 holds BERT's [CLS] piece -- TODO confirm against the embedder.\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, embs: torch.Tensor,\n",
    "                mask: torch.Tensor=None) -> torch.Tensor:\n",
    "        # extract first token tensor; `mask` is accepted for interface parity but unused\n",
    "        return embs[:, 0]\n",
    "\n",
    "    @overrides\n",
    "    def get_output_dim(self) -> int:\n",
    "        # Pooling does not change the feature size, so report the embedder's width.\n",
    "        return BERT_DIM\n",
    "\n",
    "# The pooler defines no constructor of its own and keeps no vocabulary-dependent state,\n",
    "# so it is built without arguments instead of passing `vocab` to the base encoder.\n",
    "encoder = BertSentencePooler()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how simple and modular the code for initializing the model is. All the complexity is delegated to each component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BaselineModel(\n",
    "    word_embeddings, \n",
    "    encoder, \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the GPU when one is available; on CPU there is nothing to do.\n",
    "if USE_GPU:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the model inputs and the target tensor out of the sample batch.\n",
    "tokens = batch[\"tokens\"]\n",
    "labels = batch[\"label\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tokens': tensor([[  101,   100,   100,  ...,   102,     0,     0],\n",
       "         [  101,   100,  1997,  ...,  1997,   100,   102],\n",
       "         [  101,  1000,   100,  ...,  2062,  2411,   102],\n",
       "         ...,\n",
       "         [  101,   100,   100,  ..., 29625,   100,   102],\n",
       "         [  101,   100,  1997,  ..., 12064,  2135,   102],\n",
       "         [  101,   100,  5993,  ...,  2089,  2421,   102]]),\n",
       " 'tokens-offsets': tensor([[ 1,  2,  3,  ..., 96,  0,  0],\n",
       "         [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "         [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "         ...,\n",
       "         [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "         [ 1,  2,  3,  ..., 96, 97, 98],\n",
       "         [ 1,  2,  3,  ..., 96, 97, 98]]),\n",
       " 'mask': tensor([[1, 1, 1,  ..., 1, 0, 0],\n",
       "         [1, 1, 1,  ..., 1, 1, 1],\n",
       "         [1, 1, 1,  ..., 1, 1, 1],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 1, 1, 1],\n",
       "         [1, 1, 1,  ..., 1, 1, 1],\n",
       "         [1, 1, 1,  ..., 1, 1, 1]])}"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 1, 1,  ..., 1, 0, 0],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        ...,\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        [1, 1, 1,  ..., 1, 1, 1],\n",
       "        [1, 1, 1,  ..., 1, 1, 1]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask = get_text_field_mask(tokens)\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.4729,  0.2816,  0.4229, -0.0508,  0.2860,  0.1277],\n",
       "        [ 0.4567,  0.1642,  0.1731, -0.0892,  0.0180, -0.0179],\n",
       "        [ 0.4481,  0.1152,  0.1730, -0.2053, -0.1605, -0.1381],\n",
       "        [ 0.3364,  0.1633,  0.1099, -0.1427, -0.0838,  0.0581],\n",
       "        [ 0.3577,  0.1493,  0.0643, -0.1915, -0.0968, -0.1219],\n",
       "        [ 0.2557,  0.1389,  0.2810, -0.1334, -0.0666, -0.0243],\n",
       "        [ 0.2137,  0.1822,  0.2200, -0.1423, -0.0362, -0.1285],\n",
       "        [ 0.5856,  0.1423,  0.3972, -0.0657,  0.0962,  0.1120],\n",
       "        [ 0.3897,  0.2418,  0.1570, -0.2679,  0.0226, -0.1083],\n",
       "        [ 0.1805, -0.0639, -0.0843, -0.0652, -0.0810, -0.0309],\n",
       "        [ 0.2778,  0.1103,  0.3753,  0.0957,  0.4034, -0.0487]],\n",
       "       grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = model.word_embeddings(tokens)\n",
    "state = model.encoder(embeddings, mask)\n",
    "class_logits = model.projection(state)\n",
    "class_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'class_logits': tensor([[ 0.4487,  0.2121,  0.3770, -0.0024, -0.0786, -0.1584],\n",
       "         [ 0.3425,  0.1704,  0.2082, -0.1272, -0.0710, -0.2600],\n",
       "         [ 0.3889,  0.2092,  0.3206, -0.0456, -0.0397, -0.1992],\n",
       "         [ 0.3932, -0.0668,  0.1999, -0.0447,  0.2420,  0.0691],\n",
       "         [ 0.3591,  0.2240,  0.1109, -0.1412, -0.1365, -0.1329],\n",
       "         [ 0.2854,  0.1720,  0.0455, -0.0622, -0.1601,  0.0505],\n",
       "         [ 0.1628,  0.0641,  0.1516, -0.0599, -0.0546, -0.0176],\n",
       "         [ 0.4083,  0.0808,  0.1382, -0.1770, -0.0371,  0.0616],\n",
       "         [ 0.3897,  0.2182, -0.0965, -0.2455, -0.0945, -0.0205],\n",
       "         [ 0.1571,  0.0584,  0.1086, -0.1595, -0.0177, -0.1803],\n",
       "         [ 0.1892,  0.1309,  0.4160, -0.0852,  0.1782, -0.2983]],\n",
       "        grad_fn=<AddmmBackward>),\n",
       " 'loss': tensor(0.7260, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(**batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = model(**batch)[\"loss\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7254, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x.grad for x in list(model.encoder.parameters())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters(), lr=config.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.training.trainer import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    iterator=iterator,\n",
    "    train_dataset=train_ds,\n",
    "    cuda_device=0 if USE_GPU else -1,\n",
    "    num_epochs=config.epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/07/2019 17:38:09 - INFO - allennlp.training.trainer -   Beginning training.\n",
      "02/07/2019 17:38:09 - INFO - allennlp.training.trainer -   Epoch 0/1\n",
      "02/07/2019 17:38:09 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1287.417856\n",
      "02/07/2019 17:38:09 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 0.7036 ||: 100%|██████████| 5/5 [01:28<00:00, 18.71s/it]\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   loss          |     0.704  |       N/A\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1287.418  |       N/A\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   Epoch duration: 00:01:29\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   Estimated training time remaining: 0:01:29\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   Epoch 1/1\n",
      "02/07/2019 17:39:38 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1870.708736\n",
      "02/07/2019 17:39:39 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 0.6123 ||: 100%|██████████| 5/5 [00:50<00:00, 10.44s/it]\n",
      "02/07/2019 17:40:29 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/07/2019 17:40:29 - INFO - allennlp.training.trainer -   loss          |     0.612  |       N/A\n",
      "02/07/2019 17:40:29 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1870.709  |       N/A\n",
      "02/07/2019 17:40:29 - INFO - allennlp.training.trainer -   Epoch duration: 00:00:50\n"
     ]
    }
   ],
   "source": [
    "metrics = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.data.iterators import DataIterator\n",
    "from tqdm import tqdm\n",
    "from scipy.special import expit # the sigmoid function\n",
    "\n",
    "def tonp(tsr): return tsr.detach().cpu().numpy()\n",
    "\n",
    "class Predictor:\n",
    "    \"\"\"Runs a model over a dataset and collects the sigmoid of its class logits.\"\"\"\n",
    "\n",
    "    def __init__(self, model: Model, iterator: DataIterator,\n",
    "                 cuda_device: int=-1) -> None:\n",
    "        self.cuda_device = cuda_device\n",
    "        self.iterator = iterator\n",
    "        self.model = model\n",
    "\n",
    "    def _extract_data(self, batch) -> np.ndarray:\n",
    "        \"\"\"Forward one batch and squash its \"class_logits\" through a sigmoid.\"\"\"\n",
    "        logits = self.model(**batch)[\"class_logits\"]\n",
    "        return expit(tonp(logits))\n",
    "\n",
    "    def predict(self, ds: Iterable[Instance]) -> np.ndarray:\n",
    "        \"\"\"Return the predictions for `ds`, concatenated along axis 0 in iterator order.\"\"\"\n",
    "        batches = self.iterator(ds, num_epochs=1, shuffle=False)\n",
    "        self.model.eval()\n",
    "        n_batches = self.iterator.get_num_batches(ds)\n",
    "        outputs = []\n",
    "        # Inference only: no autograd bookkeeping is needed.\n",
    "        with torch.no_grad():\n",
    "            for raw_batch in tqdm(batches, total=n_batches):\n",
    "                on_device = nn_util.move_to_device(raw_batch, self.cuda_device)\n",
    "                outputs.append(self._extract_data(on_device))\n",
    "        return np.concatenate(outputs, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.data.iterators import BasicIterator\n",
    "# iterate over the dataset without changing its order; reuse the configured batch\n",
    "# size instead of repeating the literal 64\n",
    "seq_iterator = BasicIterator(batch_size=config.batch_size)\n",
    "seq_iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:46<00:00,  8.34s/it]\n",
      "100%|██████████| 4/4 [00:44<00:00, 11.04s/it]\n"
     ]
    }
   ],
   "source": [
    "predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)\n",
    "train_preds = predictor.predict(train_ds) \n",
    "test_preds = predictor.predict(test_ds) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
